/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.multilabel.baseline;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.Model;
import org.tribuo.MutableDataset;
import org.tribuo.Trainer;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.multilabel.ImmutableMultiLabelInfo;
import org.tribuo.multilabel.MultiLabel;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;

import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Map;

/**
 * A {@link Trainer} for {@link MultiLabel} outputs which fits one binary {@link Model}
 * per element of the output domain, each predicting a single {@link Label}.
 * <p>
 * The per-label models are bundled into an {@link IndependentMultiLabelModel}, which
 * provides the {@link MultiLabel} prediction.
 * <p>
 * The per-label models are trained one after another in a single loop; training them
 * in parallel would be a possible optimisation.
 */
public class IndependentMultiLabelTrainer implements Trainer<MultiLabel> {

    @Config(mandatory = true,description="Trainer to use for each individual label.")
    private Trainer<Label> innerTrainer;

    private int trainInvocationCounter = 0;

    /**
     * No-argument constructor, for olcut.
     */
    private IndependentMultiLabelTrainer() {}

    /**
     * Builds a multi-label trainer which delegates the training of every
     * per-label binary model to the supplied trainer.
     * @param innerTrainer The trainer to use for each individual label.
     */
    public IndependentMultiLabelTrainer(Trainer<Label> innerTrainer) {
        this.innerTrainer = innerTrainer;
    }

    /**
     * Trains one binary model for each element of the dataset's output domain and
     * returns them wrapped in an {@link IndependentMultiLabelModel}.
     * @param examples The multi-label training data.
     * @param runProvenance Provenance for this training run, stored in the returned model's provenance.
     * @return The model wrapping the per-label binary models.
     * @throws IllegalArgumentException If the dataset's output info reports any unknown outputs.
     */
    @Override
    public Model<MultiLabel> train(Dataset<MultiLabel> examples, Map<String, Provenance> runProvenance) {
        boolean containsUnknownOutputs = examples.getOutputInfo().getUnknownCount() > 0;
        if (containsUnknownOutputs) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }

        ImmutableMultiLabelInfo outputIDInfo = (ImmutableMultiLabelInfo) examples.getOutputIDInfo();
        ImmutableFeatureMap featureIDMap = examples.getFeatureIDMap();
        DatasetProvenance sourceProvenance = examples.getProvenance();

        ArrayList<Label> labels = new ArrayList<>();
        ArrayList<Model<Label>> binaryModels = new ArrayList<>();

        //TODO supply more suitable provenance showing it's a single dimension out of many.
        // A single scratch dataset is reused for every label; it is emptied before each refill.
        MutableDataset<Label> binaryData = new MutableDataset<>(sourceProvenance, new LabelFactory());
        for (MultiLabel domainElement : outputIDInfo.getDomain()) {
            Label target = new Label(domainElement.getLabelString());
            labels.add(target);
            fillBinaryDataset(binaryData, examples, target);
            binaryModels.add(innerTrainer.train(binaryData));
        }

        // The trainer provenance is captured before the invocation counter is incremented.
        OffsetDateTime trainedAt = OffsetDateTime.now();
        TrainerProvenance trainerProvenance = getProvenance();
        ModelProvenance modelProvenance = new ModelProvenance(IndependentMultiLabelModel.class.getName(), trainedAt, sourceProvenance, trainerProvenance, runProvenance);
        trainInvocationCounter++;

        return new IndependentMultiLabelModel(labels, binaryModels, modelProvenance, featureIDMap, outputIDInfo);
    }

    /**
     * Empties {@code target} and refills it with one {@link BinaryExample} per example in
     * {@code source}, relabelled with respect to {@code label}.
     * @param target The scratch dataset to clear and repopulate.
     * @param source The multi-label examples to convert.
     * @param label The label the binary problem is built around.
     */
    private static void fillBinaryDataset(MutableDataset<Label> target, Dataset<MultiLabel> source, Label label) {
        target.clear();
        for (Example<MultiLabel> multiLabelExample : source) {
            // This sets the label in the new example to either the supplied label or MultiLabel.NEGATIVE_LABEL_STRING.
            Label binaryLabel = multiLabelExample.getOutput().createLabel(label);
            target.add(new BinaryExample(multiLabelExample, binaryLabel));
        }
    }

    /**
     * Returns the current value of the training invocation counter, which is
     * incremented once at the end of each completed call to {@code train}.
     * @return The invocation count.
     */
    @Override
    public int getInvocationCount() {
        return trainInvocationCounter;
    }

    @Override
    public String toString() {
        String innerDescription = innerTrainer.toString();
        return "IndependentMultiLabelTrainer(innerTrainer=" + innerDescription + ")";
    }

    /**
     * Creates a fresh {@link TrainerProvenanceImpl} describing this trainer.
     * @return The trainer provenance.
     */
    @Override
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }
}

